import torch
import torch.nn as nn
from torch import float16

from ..layers.layer_utils import Conv2d_tucker_fixed, Conv2d_tucker_adaptive


##########################################################################################################################################
# net metric utils
##########################################################################################################################################
@torch.no_grad()
def full_count_params(NN, count_bias=False, with_grads=False, count_linear=False):
    """
    Function that counts the total number of parameters needed for a full rank version of NN

    Convolutions (nn.Conv2d, Conv2d_tucker_fixed, Conv2d_tucker_adaptive) are counted as a
    dense kernel of kernel_size[0] * kernel_size[1] * in_channels * out_channels entries,
    linear layers (only when count_linear is True) as in_features * out_features entries.
    Any other module is traversed recursively through its children.

    INPUTS:
    NN: neural network (or a single layer)
    count_bias : flag variable, True if the biases are to be included in the total or not
    with_grads : if True every counted entry is doubled (value + gradient)
    count_linear : if True, nn.Linear layers are included in the total
    OUTPUTS:
    total_params : total number of parameters in the full rank version of NN
    """
    factor = 2 if with_grads else 1

    if isinstance(NN, nn.Linear) and count_linear:
        total_params = NN.in_features * NN.out_features
        # Linear biases used to be skipped even with count_bias=True, unlike the conv
        # branch below and unlike count_params_train / count_params_test.
        if count_bias and NN.bias is not None:
            total_params += NN.bias.numel()
        return factor * total_params

    # nn.Conv2d is tested first so that plain convolutions short-circuit the check.
    if isinstance(NN, nn.Conv2d) or isinstance(NN, Conv2d_tucker_fixed) or isinstance(NN, Conv2d_tucker_adaptive):
        total_params = NN.kernel_size[0] * NN.kernel_size[1] * NN.in_channels * NN.out_channels
        if count_bias and NN.bias is not None:
            total_params += NN.bias.numel()
        return factor * total_params

    return sum(
        full_count_params(child, count_bias=count_bias, with_grads=with_grads, count_linear=count_linear)
        for child in NN.children()
    )


@torch.no_grad()
def get_dynamic_ranks(NN):
    """
    Function that collects the current dynamic ranks of the low-rank layers of the network

    A module exposing a ``dynamic_rank`` attribute contributes that value as one entry
    and is not searched further; every other module is searched recursively through
    its children, in registration order.

    INPUTS:
    NN: neural network (or a single layer)
    OUTPUTS:
    ranks : list with one ``dynamic_rank`` entry per low-rank layer found (empty if none)
    """
    if hasattr(NN, 'dynamic_rank'):
        return [NN.dynamic_rank]

    ranks = []
    for child in NN.children():
        ranks.extend(get_dynamic_ranks(child))
    return ranks


@torch.no_grad()
def count_params(T, with_grads=False):
    """
    function to count number of parameters inside a tensor

    INPUT:
    T : torch.tensor or None
    with_grads : if True the count is doubled (value + gradient)
    output:
    number of parameters contained in T (0 when T is None)
    """
    # The None check must come first: the old code evaluated len(T.shape) before it
    # and therefore crashed on None (e.g. a layer built without bias).
    if T is None:
        return 0
    # numel() also covers 0-d tensors, which previously raised on T.shape[0].
    n_elements = T.numel()
    return 2 * n_elements if with_grads else n_elements


@torch.no_grad()
def count_params_train(NN, count_bias=False, with_grads=False, count_linear=False):
    """
    function to count the parameters in the train phase

    For Tucker layers only the slices up to the current ``dynamic_rank`` of each factor
    are counted. Conv2d_tucker_adaptive counts Ks, Us, Qst up to rank r, and U_hats,
    M_hats (rows) and the core C up to 2*r per mode. Conv2d_tucker_fixed counts
    everything up to r. Plain nn.Conv2d (and nn.Linear when count_linear is True)
    count their stored parameters. Children are always visited recursively.

    INPUTS:
    NN : neural network
    count_bias : flag variable, True if the biases are to be included in the total or not
    with_grads : if True, the Ks factors, the core C and dense weights are counted
                 twice (value + gradient); the other factors and biases are counted once
    count_linear : if True, nn.Linear layers are included in the total
    OUTPUTS:
    total_params : number of parameters counted for NN
    """
    total_params = 0

    if isinstance(NN, Conv2d_tucker_adaptive):
        d = len(NN.C.shape)
        for i in range(d):
            r = NN.dynamic_rank[i]
            total_params += count_params(NN.Ks[i][:, :r], with_grads)
            total_params += count_params(NN.Us[i][:, :r])
            total_params += count_params(NN.U_hats[i][:, :2 * r])
            total_params += count_params(NN.M_hats[i][:2 * r, :r])
            total_params += count_params(NN.Qst[i][:r])
        # Core truncated to 2*r in every one of its d modes (presumably the augmented
        # basis used by the rank-adaptive step -- confirm against Conv2d_tucker_adaptive).
        core = NN.C[tuple(slice(None, 2 * NN.dynamic_rank[i]) for i in range(d))]
        total_params += count_params(core, with_grads)
        # Bias-less layers store None; guard as full_count_params does.
        if count_bias and NN.bias is not None:
            total_params += count_params(NN.bias)
    elif isinstance(NN, Conv2d_tucker_fixed):
        d = len(NN.C.shape)
        for i in range(d):
            r = NN.dynamic_rank[i]
            total_params += count_params(NN.Ks[i][:, :r], with_grads)
            total_params += count_params(NN.Us[i][:, :r])
            total_params += count_params(NN.M_hats[i][:r, :r])
            total_params += count_params(NN.Qst[i][:r])
        core = NN.C[tuple(slice(None, NN.dynamic_rank[i]) for i in range(d))]
        total_params += count_params(core, with_grads)
        if count_bias and NN.bias is not None:
            total_params += count_params(NN.bias)
    elif isinstance(NN, nn.Conv2d) or (isinstance(NN, nn.Linear) and count_linear):
        for n, p in NN.named_parameters():
            if 'bias' not in n:
                total_params += count_params(p, with_grads)  # weights: add with grads
            elif count_bias:
                total_params += count_params(p)

    for child in NN.children():
        total_params += count_params_train(child, count_bias=count_bias, with_grads=with_grads,
                                           count_linear=count_linear)

    return total_params


@torch.no_grad()
def count_params_test(NN, count_bias=False, count_linear=False):
    """
    function to count the parameters in the test phase

    For Tucker layers (fixed or adaptive) only the Us factors and the core C are
    counted, each truncated to the current ``dynamic_rank`` in every mode. Plain
    nn.Conv2d (and nn.Linear when count_linear is True) count their stored
    parameters. Children are always visited recursively.

    INPUTS:
    NN : neural network
    count_bias : flag variable, True if the biases are to be included in the total or not
    count_linear : if True, nn.Linear layers are included in the total
    OUTPUTS:
    total_params : number of parameters counted for NN
    """
    total_params = 0

    if isinstance(NN, Conv2d_tucker_adaptive) or isinstance(NN, Conv2d_tucker_fixed):
        d = len(NN.C.shape)
        for i in range(d):
            total_params += count_params(NN.Us[i][:, :NN.dynamic_rank[i]])
        # Core truncated to the dynamic rank in each of its d modes.
        core = NN.C[tuple(slice(None, NN.dynamic_rank[i]) for i in range(d))]
        total_params += count_params(core)
        # Bias-less layers store None; guard as full_count_params does.
        if count_bias and NN.bias is not None:
            total_params += count_params(NN.bias)
    elif isinstance(NN, nn.Conv2d) or (isinstance(NN, nn.Linear) and count_linear):
        for n, p in NN.named_parameters():
            # weights always count, biases only when requested
            if 'bias' not in n or count_bias:
                total_params += count_params(p)

    for child in NN.children():
        total_params += count_params_test(child, count_bias=count_bias, count_linear=count_linear)
    return total_params


@torch.no_grad()
def max_dist_M_I(NN):
    """
    Average distance of the truncated ``M_hats`` matrices from the identity.

    For every layer ``l`` of ``NN.lr_model`` with a truthy ``l.dlrt`` attribute whose
    string representation contains 'conv', and for every mode ``i``, this computes
    ``torch.norm(M_hats[i][:r, :r] - I_r)`` with ``r = l.dynamic_rank[i]``.
    Despite the name, the per-mode values are *averaged* (not maximised) over
    those layers. Each qualifying layer contributes exactly once.

    INPUTS:
    NN : network exposing an iterable ``lr_model`` of layers
    OUTPUTS:
    list of floats, one mean distance per mode; empty list if no layer qualifies
    """
    rows = []

    for layer in NN.lr_model:
        if not (hasattr(layer, 'dlrt') and layer.dlrt and 'conv' in str(layer).lower()):
            continue
        # Converted to Python floats so layers on different devices can be combined.
        rows.append([
            float(torch.norm(M[:layer.dynamic_rank[i], :layer.dynamic_rank[i]]
                             - torch.eye(layer.dynamic_rank[i], m=layer.dynamic_rank[i]).to(M.device)))
            for i, M in enumerate(layer.M_hats)
        ])

    # Previously the first layer was appended twice (skewing the mean) and an empty
    # selection crashed in torch.mean.
    if not rows:
        return []

    return [float(el) for el in torch.mean(torch.tensor(rows), dim=0)]


def accuracy(outputs, labels):
    """
    Fraction of samples whose arg-max prediction (over dim 1) equals the label.

    INPUTS:
    outputs : score tensor with classes along dim 1 (presumably (batch, classes) logits)
    labels : integer class labels compared element-wise with the arg-max result
    OUTPUTS:
    0-d float16 tensor holding the mean accuracy
    """
    predictions = torch.argmax(outputs.detach(), dim=1)
    # .to() instead of torch.tensor(<tensor>, dtype=...) avoids PyTorch's
    # copy-construct UserWarning; device and float16 result dtype are unchanged.
    return torch.mean((predictions == labels).to(float16))


def accuracy_top_k(output, target, topk=(1,)):
    """
    Count (not normalise) the correct top-k predictions for each k in ``topk``.

    INPUTS:
    output : score tensor with classes along dim 1
    target : class labels, one per row of ``output``
    topk : iterable of k values to evaluate
    OUTPUTS:
    dict mapping each k to a 0-d float tensor with the number of samples whose
    label is among the k highest-scoring classes
    """
    largest_k = max(topk)
    # indices of the largest_k best classes, transposed to shape (largest_k, batch)
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    return {k: hits[:k].reshape(-1).float().sum(0) for k in topk}


##############################################################################################################################
# training loop utils
##############################################################################################################################

def check_params_change(NN, op):
    """
    Debug helper: report which parameters of ``NN.lr_model`` are modified by ``op``.

    Snapshots (clones) every parameter of ``NN.lr_model``, calls ``op()``, then prints
    for each parameter whether it is no longer ``torch.allclose`` to its snapshot.
    Afterwards it blocks on ``input()`` so the output can be inspected interactively.

    INPUTS:
    NN : object exposing an ``lr_model`` module
    op : zero-argument callable to run between the two snapshots
    """
    print(f'operation: {str(op)}')

    snapshot = {name: param.clone() for name, param in NN.lr_model.named_parameters()}

    op()

    for name, param in NN.lr_model.named_parameters():
        changed = not torch.allclose(param, snapshot[name])
        print(f'name n: {name} param changed: {changed}')

    input()
    print('\n')
    print('=' * 100)
